{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BWA5KfG8QguH"
      },
      "source": [
        "Notebook for score normalization."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "arboOzAbQkUa"
      },
      "outputs": [],
      "source": [
        "# Imports\n",
        "import pandas as pd\n",
        "import pickle\n",
        "import numpy as np"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vHDvJkFLHPCh"
      },
      "outputs": [],
      "source": [
        "# May be needed for numerical stability, with GPT2 especially.\n",
        "def logsumexp(x):\n",
        "    c = x.max()\n",
        "    return c + np.log(np.sum(np.exp(x - c)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "alqHmsWSDMG-"
      },
      "source": [
        "Conditional Probability Calculations"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iBKFHwqgDPVy"
      },
      "source": [
        "T5 Calculations"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0kZCN9wbQmVA"
      },
      "outputs": [],
      "source": [
        "# Load appropriate scores.\n",
        "# Binds the global `obj` to whatever is unpickled from PICKLE_PATH; the scoring\n",
        "# loop below indexes it with the keys 'us_us', 'us_uk', 'uk_us' and 'uk_uk'.\n",
        "# Later cells rebind both PICKLE_PATH and `obj`, so rerun this cell before\n",
        "# rerunning the T5 conditional cells.\n",
        "# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.\n",
        "PICKLE_PATH = 't5_conditional_adjacent.pickle'\n",
        "with open(PICKLE_PATH, 'rb') as pfile:\n",
        "  obj = pickle.load(pfile)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VTFkZ0XuqkjI"
      },
      "outputs": [],
      "source": [
        "# Get the set of prompt templates.\n",
        "# Each template contains a {here} slot and an \u003cextra_id_0\u003e marker; how they are\n",
        "# filled is not shown in this notebook (the scores are loaded precomputed).\n",
        "prompt_templates = [\n",
        "    \"My preferred words are {here} and \u003cextra_id_0\u003e.\",\n",
        "    \"My preferred words are {here}, \u003cextra_id_0\u003e, and tree.\",\n",
        "    \"She wrote the words {here} and \u003cextra_id_0\u003e.\",\n",
        "    \"She wrote the words {here} and \u003cextra_id_0\u003e in her notebook.\",\n",
        "    \"She wrote the words {here}, \u003cextra_id_0\u003e, and cabbage.\",\n",
        "    \"I wrote the words {here} and \u003cextra_id_0\u003e.\",\n",
        "    \"I wrote the words {here} and \u003cextra_id_0\u003e in my notebook.\",\n",
        "    \"I wrote the words {here}, \u003cextra_id_0\u003e, and cabbage.\",\n",
        "    \"He wrote the words {here} and \u003cextra_id_0\u003e.\",\n",
        "    \"He wrote the words {here} and \u003cextra_id_0\u003e in his notebook.\",\n",
        "    \"He wrote the words {here}, \u003cextra_id_0\u003e, and cabbage.\",\n",
        "    \"We wrote the words {here} and \u003cextra_id_0\u003e.\",\n",
        "    \"We wrote the words {here} and \u003cextra_id_0\u003e in our notebook.\",\n",
        "    \"We wrote the words {here}, \u003cextra_id_0\u003e, and cabbage.\",\n",
        "    \"Mary wrote the words {here} and \u003cextra_id_0\u003e.\",\n",
        "    \"Mary wrote the words {here} and \u003cextra_id_0\u003e in her notebook.\",\n",
        "    \"Mary wrote the words {here}, \u003cextra_id_0\u003e, and cabbage.\",\n",
        "    \"Please spell {here} and \u003cextra_id_0\u003e.\",\n",
        "    \"Please spell {here}, \u003cextra_id_0\u003e, and panther.\",\n",
        "    \"Please spell {here} and \u003cextra_id_0\u003e correctly.\",\n",
        "    \"Say {here} and \u003cextra_id_0\u003e.\",\n",
        "    \"Say {here}, \u003cextra_id_0\u003e, and tapestry.\",\n",
        "    \"Say {here} and \u003cextra_id_0\u003e again.\",\n",
        "    \"The first words on the list were {here} and \u003cextra_id_0\u003e.\",\n",
        "    \"The first words on the list were {here}, \u003cextra_id_0\u003e, and oligarchy.\",\n",
        "    \"The easiest words on the list were {here} and \u003cextra_id_0\u003e.\",\n",
        "    \"The easiest words on the list were {here}, \u003cextra_id_0\u003e, and oligarchy.\",\n",
        "    \"The hardest words on the list were {here} and \u003cextra_id_0\u003e.\",\n",
        "    # NOTE(review): the next template has a '$' before {here}, unlike every other\n",
        "    # template in this list -- confirm whether that is intentional.\n",
        "    \"The hardest words on the list were ${here}, \u003cextra_id_0\u003e, and oligarchy.\",\n",
        "]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fKzrhdZj-hgE"
      },
      "outputs": [],
      "source": [
        "# Number of prompt templates.\n",
        "nprompts = len(prompt_templates)\n",
        "# Instances per prompt template.\n",
        "limit = 16028\n",
        "\n",
        "# Normalized Conditionals.\n",
        "grand_us_us = []\n",
        "grand_us_uk = []\n",
        "grand_uk_us = []\n",
        "grand_uk_uk = []\n",
        "\n",
        "# Consistency preferences.\n",
        "grand_con_us = []\n",
        "grand_con_uk = []\n",
        "\n",
        "# Loop over prompts (the prompt_templates list defined in the cell above).\n",
        "for i, xp in enumerate(prompt_templates):\n",
        "  # Template i owns positions [i*limit, (i+1)*limit) of every score list.\n",
        "  block = slice(i*limit, (i+1)*limit)\n",
        "  block_scores = list(zip(obj['us_us'][block], obj['us_uk'][block], obj['uk_us'][block], obj['uk_uk'][block]))\n",
        "  # Fail with a readable message instead of a division by zero further down.\n",
        "  if not block_scores:\n",
        "    raise ValueError('No scores found for prompt template {} ({!r}).'.format(i, xp))\n",
        "\n",
        "  us_con = []\n",
        "  uk_con = []\n",
        "  total_us_us = 0\n",
        "  total_us_uk = 0\n",
        "  total_uk_us = 0\n",
        "  total_uk_uk = 0\n",
        "  for us_us, us_uk, uk_us, uk_uk in block_scores:\n",
        "\n",
        "    # Check preferences: 1 when the us_us score is at least the us_uk score\n",
        "    # (likewise uk_uk against uk_us), else 0.\n",
        "    us_con.append(int(us_us \u003e= us_uk))\n",
        "    uk_con.append(int(uk_uk \u003e= uk_us))\n",
        "\n",
        "    # Exponentiate the scores and renormalize each prompt's pair so that the\n",
        "    # two alternatives sum to 1.\n",
        "    p_us_us, p_us_uk, p_uk_us, p_uk_uk = np.exp(us_us), np.exp(us_uk), np.exp(uk_us), np.exp(uk_uk)\n",
        "    us_norm = p_us_us + p_us_uk\n",
        "    uk_norm = p_uk_us + p_uk_uk\n",
        "    total_us_us += p_us_us/us_norm\n",
        "    total_us_uk += p_us_uk/us_norm\n",
        "    total_uk_us += p_uk_us/uk_norm\n",
        "    total_uk_uk += p_uk_uk/uk_norm\n",
        "\n",
        "  # Renormalize the accumulated sums per prompt variety.\n",
        "  us_total = total_us_us + total_us_uk\n",
        "  uk_total = total_uk_us + total_uk_uk\n",
        "  normed_us_us = total_us_us/us_total\n",
        "  normed_us_uk = total_us_uk/us_total\n",
        "  normed_uk_us = total_uk_us/uk_total\n",
        "  normed_uk_uk = total_uk_uk/uk_total\n",
        "\n",
        "  # Print some statistics per prompt template.\n",
        "  print(xp)\n",
        "  print(\"-- US UK\")\n",
        "  print(\"US\", normed_us_us, normed_us_uk)\n",
        "  print(\"UK\", normed_uk_us, normed_uk_uk)\n",
        "\n",
        "  # Add to grand total.\n",
        "  grand_us_us.append(normed_us_us)\n",
        "  grand_us_uk.append(normed_us_uk)\n",
        "  grand_uk_us.append(normed_uk_us)\n",
        "  grand_uk_uk.append(normed_uk_uk)\n",
        "\n",
        "  # Print consistency per prompt.\n",
        "  mean_us_con = np.mean(us_con)\n",
        "  mean_uk_con = np.mean(uk_con)\n",
        "  print(\"================\")\n",
        "  print('us_con', 'uk_con')\n",
        "  print(mean_us_con, mean_uk_con)\n",
        "  print()\n",
        "\n",
        "  # Aggregate\n",
        "  grand_con_us.append(mean_us_con)\n",
        "  grand_con_uk.append(mean_uk_con)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "omZyZis4kuc2"
      },
      "outputs": [],
      "source": [
        "# Print aggregate stats.\n",
        "# Mean and standard deviation of each normalized conditional over the\n",
        "# per-template values collected by the loop above.\n",
        "for label, per_prompt in (\n",
        "    ('us_us', grand_us_us),\n",
        "    ('us_uk', grand_us_uk),\n",
        "    ('uk_us', grand_uk_us),\n",
        "    ('uk_uk', grand_uk_uk),\n",
        "):\n",
        "  print(label, np.mean(per_prompt), np.std(per_prompt))\n",
        "\n",
        "# Mean of the per-template consistency rates.\n",
        "for label, per_prompt in (('us_con', grand_con_us), ('uk_con', grand_con_uk)):\n",
        "  print(label, np.mean(per_prompt))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6WSjOJswhu7-"
      },
      "source": [
        "GPT2 Calculations"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-BsH6vfGwUzN"
      },
      "outputs": [],
      "source": [
        "# Load appropriate scores.\n",
        "# Rebinds the globals PICKLE_PATH and `obj`, replacing what the T5 cells loaded.\n",
        "# The next cell reads obj['us_us'], obj['us_uk'], obj['uk_us'] and obj['uk_uk']\n",
        "# and indexes every entry with the scoring regime `tp`.\n",
        "# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.\n",
        "PICKLE_PATH = 'gpt2_adjacent_scores.pickle'\n",
        "with open(PICKLE_PATH, 'rb') as pfile:\n",
        "  obj = pickle.load(pfile)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zCQ1XuzEQ7bV"
      },
      "outputs": [],
      "source": [
        "# Scoring regime type (0 = to end of target, 1 = to end of sentence (EOS), 2 = full sentence \"joint\")\n",
        "tp = 1\n",
        "\n",
        "# Per-instance consistency flags and running sums of normalized conditionals.\n",
        "us_con = []\n",
        "uk_con = []\n",
        "total_us_us = 0\n",
        "total_us_uk = 0\n",
        "total_uk_us = 0\n",
        "total_uk_uk = 0\n",
        "for us_us, us_uk, uk_us, uk_uk in zip(obj['us_us'], obj['us_uk'], obj['uk_us'], obj['uk_uk']):\n",
        "\n",
        "  # Keep only the score for the selected regime.\n",
        "  us_us, us_uk, uk_us, uk_uk = us_us[tp], us_uk[tp], uk_us[tp], uk_uk[tp]\n",
        "\n",
        "  # Check preferences: 1 when the us_us score is at least the us_uk score\n",
        "  # (likewise uk_uk against uk_us), else 0.\n",
        "  us_con.append(int(us_us \u003e= us_uk))\n",
        "  uk_con.append(int(uk_uk \u003e= uk_us))\n",
        "\n",
        "  # Normalize each prompt's pair of scores so the two alternatives sum to 1;\n",
        "  # subtracting logsumexp before exponentiating avoids underflow.\n",
        "  us_pair = np.array([us_us, us_uk])\n",
        "  p_us_us, p_us_uk = np.exp(us_pair - logsumexp(us_pair))\n",
        "\n",
        "  uk_pair = np.array([uk_us, uk_uk])\n",
        "  p_uk_us, p_uk_uk = np.exp(uk_pair - logsumexp(uk_pair))\n",
        "\n",
        "  total_us_us += p_us_us\n",
        "  total_us_uk += p_us_uk\n",
        "  total_uk_us += p_uk_us\n",
        "  total_uk_uk += p_uk_uk\n",
        "\n",
        "# Renormalize the accumulated sums per prompt variety.\n",
        "us_total = total_us_us + total_us_uk\n",
        "uk_total = total_uk_us + total_uk_uk\n",
        "normed_us_us = total_us_us/us_total\n",
        "normed_us_uk = total_us_uk/us_total\n",
        "normed_uk_us = total_uk_us/uk_total\n",
        "normed_uk_uk = total_uk_uk/uk_total\n",
        "\n",
        "# Print aggregate conditional probabilities and consistency values.\n",
        "print('us_us', 'us_uk', 'uk_us','uk_uk' )\n",
        "print(normed_us_us, normed_us_uk, normed_uk_us, normed_uk_uk)\n",
        "print('us_con', 'uk_con')\n",
        "print(np.mean(us_con), np.mean(uk_con))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EsfIhMH5z-4R"
      },
      "source": [
        "Log-likelihood Ratio Calculations for T5"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XVkQJl9Rz-Kc"
      },
      "outputs": [],
      "source": [
        "# Load appropriate scores.\n",
        "# Rebinds the globals PICKLE_PATH and `obj`, replacing what earlier cells loaded.\n",
        "# The next cell reads obj['us_us'], obj['us_uk'], obj['uk_us'] and obj['uk_uk'].\n",
        "# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.\n",
        "PICKLE_PATH = 't5_joint_adjacent.pickle'\n",
        "with open(PICKLE_PATH, 'rb') as pfile:\n",
        "  obj = pickle.load(pfile)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DU8Wsl0l0iOz"
      },
      "outputs": [],
      "source": [
        "# Average over all instances of sum_ij p_ij * log(p_ij / (p_i. * p_.j)), where\n",
        "# p is the 2x2 table obtained by normalizing the four scores of an instance.\n",
        "total_llr = 0\n",
        "total_count = 0\n",
        "for us_us, us_uk, uk_us, uk_uk in zip(obj['us_us'], obj['us_uk'], obj['uk_us'], obj['uk_uk']):\n",
        "  # Normalize the four scores into a joint distribution. Rows follow the\n",
        "  # first tag of each key (prompt), columns the second tag (target).\n",
        "  x = np.array([us_us, us_uk, uk_us, uk_uk])\n",
        "  joint = np.exp(x - logsumexp(x)).reshape(2, 2)\n",
        "\n",
        "  prompt_marginal = joint.sum(axis=1, keepdims=True)\n",
        "  target_marginal = joint.sum(axis=0, keepdims=True)\n",
        "\n",
        "  # Per-sample LLR calculation.\n",
        "  # Cells whose probability underflowed to 0 contribute 0 (the limit of\n",
        "  # p * log p) instead of 0 * log(0) = nan. Taking logs of the marginals\n",
        "  # separately also keeps their product from underflowing.\n",
        "  support = joint \u003e 0\n",
        "  with np.errstate(divide='ignore'):\n",
        "    log_independent = np.log(prompt_marginal) + np.log(target_marginal)\n",
        "  llr = np.sum(joint[support] * (np.log(joint[support]) - log_independent[support]))\n",
        "  total_llr += llr\n",
        "  total_count += 1\n",
        "# Print aggregate LLR\n",
        "print('average_llr')\n",
        "print(total_llr/total_count)"
      ]
    },
    {
      "metadata": {},
      "cell_type": "markdown",
      "source": [
        "Licensed under the Apache License, Version 2.0"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1d652eIMyjSWVb_U7ab2plrFSUWhtYxh9",
          "timestamp": 1683233877599
        },
        {
          "file_id": "1G2D2g57h8WNwLqw1EnpkjRVU9Z-2hJqL",
          "timestamp": 1683232266596
        },
        {
          "file_id": "1-4MlbcP4_aTURTRnJNNheQ2FYCeCsFMy",
          "timestamp": 1672870502509
        },
        {
          "file_id": "139UPIq_HN6orC7JkwD2tA14gem9GkCL0",
          "timestamp": 1666213148860
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
